# -*- coding: UTF-8 -*-

import numpy as np
import matplotlib.pyplot as plt

import pickle


def network_predict(input, r, w1, w2):
    """Evaluate ``(r + w1 * w2) ** 5 * x`` for every weight pair and every input.

    Args:
        input: sample inputs ``x``; flattened to a 1-D vector. (The name
            shadows the builtin but is kept for caller compatibility.)
        r: constant added to ``w1 * w2`` before raising to the fifth power.
        w1, w2: weights; scalars or arrays that broadcast together (e.g. the
            output of ``np.meshgrid``). The combined value is flattened, one
            row per weight pair.

    Returns:
        2-D array whose entry ``[i, j]`` is
        ``((r + w1 * w2) ** 5).flat[i] * input.flat[j]``.
    """
    # One effective scalar weight per (w1, w2) pair, as a column vector ...
    scale = np.ravel((r + w1 * w2) ** 5)
    # ... times the inputs as a row vector gives the outer product.
    samples = np.ravel(input)
    return scale[:, np.newaxis] * samples[np.newaxis, :]


def derive_w1_graidient(w1, w2, r):
    """Return per-sample gradients of a squared-error loss with respect to ``w1``.

    The model is ``network_predict(x, r, w1, w2)`` evaluated on nine fixed
    inputs ``x = 1.0, 1.1, ..., 1.8`` with targets ``50 * x``. For each weight
    pair and each input the returned entry is

        (pred - 50 * x) * x * 5 * w2 * (r + w1 * w2) ** 4,

    which is d/dw1 of ``0.5 * (pred - 50 * x) ** 2``.

    Args:
        w1, w2: weights; scalars or arrays that broadcast together (e.g. from
            ``np.meshgrid``).
        r: constant added to ``w1 * w2`` inside the model.

    Returns:
        Array of shape ``(number of weight pairs, 9)``.

    NOTE(review): the function name is misspelled ("graidient"); it is left
    unchanged so existing callers keep working.
    """
    xs = np.linspace(1., 1.8, 9)
    targets = xs * 50.
    # dLoss/dPred for every (weight pair, sample) combination.
    residual = network_predict(xs, r, w1, w2) - targets

    # dTotalW/dw1 (without the factor 5), one value per weight pair, as a column.
    chain = np.reshape(w2 * (r + w1 * w2) ** 4, (-1, 1))

    # Same multiplication order as (5 * x) * residual, then * chain.
    return chain * (5 * xs[np.newaxis, :] * residual)


def log_norm2(gradient, shape=(100, 100)):
    """Return log10 of the Euclidean norm along axis 1, reshaped onto a grid.

    Args:
        gradient: array whose axis 1 holds the gradient vector of each grid
            point (typically shape ``(N, k)``).
        shape: shape of the returned grid. Defaults to ``(100, 100)``, the
            previously hard-coded value. Its product must equal the number of
            norms computed (``N`` for a 2-D input).

    Returns:
        Array of shape ``shape`` holding ``log10(||gradient[i]||_2)``. Rows
        with zero norm yield ``-inf``.

    Raises:
        ValueError: if the norms cannot be reshaped to ``shape``.
    """
    norms = np.linalg.norm(gradient, axis=1)
    # A zero-norm row legitimately maps to -inf. Suppress the divide-by-zero
    # RuntimeWarning that np.log10 would otherwise emit for it.
    with np.errstate(divide="ignore"):
        log_norms = np.log10(norms)
    return log_norms.reshape(shape)


def _plot_trajectory(ax, path, plot_num=170):
    """Overlay the first ``plot_num`` points of a pickled (w1s, w2s) path on ``ax``.

    The pickle at ``path`` must contain a mapping with sliceable ``"w1s"`` and
    ``"w2s"`` entries. The path is drawn as a red dashed line with a circle on
    its first point and a left-pointing triangle on its last plotted point.

    NOTE(review): ``pickle.load`` can execute arbitrary code. Only point this
    at trusted, locally generated files.
    """
    with open(path, "rb") as f:
        converge_data = pickle.load(f)
    w1s = converge_data["w1s"][0:plot_num]
    w2s = converge_data["w2s"][0:plot_num]
    # The same dashed path is drawn twice so each call can place its own
    # marker: "<" only at the last point, "o" only at the first.
    ax.plot(w1s, w2s, "r<--", markevery=[-1])
    ax.plot(w1s, w2s, "ro--", markevery=[0])


if __name__ == '__main__':
    import os

    plt.rcParams['text.usetex'] = True
    plt.rcParams['savefig.dpi'] = 300
    plt.rcParams['figure.dpi'] = 300

    # 100 x 100 parameter grid; log_norm2 reshapes its result to (100, 100)
    # by default, so these sizes must stay in sync with it.
    w1 = np.linspace(-1.3, 5.3, 100)
    w2 = np.linspace(-0., 1.8, 100)
    r = 1.

    W1, W2 = np.meshgrid(w1, w2)
    # log10 of the w1-gradient norm at each grid point.
    Z = log_norm2(derive_w1_graidient(W1, W2, r))

    fig = plt.figure()
    ax = fig.gca()

    cp = ax.contourf(W1, W2, Z, 100, cmap="RdBu")
    fig.colorbar(cp)

    ax.set_xlabel('$w_1$', fontsize=28)
    ax.set_ylabel('$w_2$', fontsize=28)
    ax.tick_params(axis='both', labelsize=22)

    # Dashed black iso-lines drawn over the filled contours.
    ax.contour(W1, W2, Z, colors="black", linestyles="dashed", linewidths=1)

    # Overlay the first 170 points of each saved trajectory.
    for path in ("./data/residual01.pkl", "./data/residual10.pkl"):
        _plot_trajectory(ax, path, plot_num=170)

    fig.tight_layout()
    # savefig does not create missing directories.
    os.makedirs("./save", exist_ok=True)
    fig.savefig('./save/residual_isometry.pdf')
    plt.show()
